Single Image Super Resolution using SRGAN
In [ ]:
import numpy as np
import pandas as pd
import os, math, sys
import glob, itertools
import argparse, random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision.models import vgg19
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from torchvision.utils import save_image, make_grid
from torchvision.transforms import ToTensor
import plotly
import plotly.express as px
import plotly.graph_objects as go
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm_notebook as tqdm
from sklearn.model_selection import train_test_split
random.seed(42)
import warnings
warnings.filterwarnings("ignore")
In [ ]:
# -----------------------------------------------------------------
# Run configuration and hyper-parameters for SRGAN training.
# -----------------------------------------------------------------
load_pretrained_models = False            # resume from saved checkpoints?
n_epochs = 150                            # total training epochs
dataset_path = "data/dogs_vs_cats/cats/"  # folder of training images
batch_size = 16
lr = 0.0001                               # Adam learning rate
b1 = 0.5                                  # Adam beta1
b2 = 0.999                                # Adam beta2
decay_epoch = 100
n_cpu = 8                                 # DataLoader worker processes
hr_height = 128  # 256
hr_width = 128   # 256
channels = 3                              # RGB

# Output and checkpoint directories (created idempotently).
for _out_dir in ("outputs_gen", "outputs_lr", "outputs_hr", "outputs", "saved_models"):
    os.makedirs(_out_dir, exist_ok=True)

cuda = torch.cuda.is_available()
hr_shape = (hr_height, hr_width)
In [ ]:
# Normalization parameters for pre-trained PyTorch models (ImageNet stats,
# required because the content loss runs images through a pre-trained VGG19).
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])


class ImageDataset(Dataset):
    """Paired LR/HR dataset for 4x super-resolution.

    Each item is a dict:
      - "lr": image bicubic-resized to (hr_height // 4, hr_width // 4)
      - "hr": image bicubic-resized to (hr_height, hr_width)
    Both tensors are normalized with ImageNet mean/std.
    """

    def __init__(self, files, hr_shape):
        hr_height, hr_width = hr_shape
        # Bugfix: the original used hr_height for BOTH dimensions in both
        # Resize calls, silently ignoring hr_width (it only worked because
        # hr_shape happened to be square).
        self.lr_transform = transforms.Compose(
            [
                transforms.Resize((hr_height // 4, hr_width // 4), Image.BICUBIC),
                transforms.ToTensor(),
                transforms.Normalize(mean, std),
            ]
        )
        self.hr_transform = transforms.Compose(
            [
                transforms.Resize((hr_height, hr_width), Image.BICUBIC),
                transforms.ToTensor(),
                transforms.Normalize(mean, std),
            ]
        )
        self.files = files

    def __getitem__(self, index):
        # Modulo wrap-around so any integer index is valid; convert to RGB
        # so grayscale/RGBA files still yield 3-channel tensors.
        img = Image.open(self.files[index % len(self.files)]).convert('RGB')
        img_lr = self.lr_transform(img)
        img_hr = self.hr_transform(img)
        return {"lr": img_lr, "hr": img_hr}

    def __len__(self):
        return len(self.files)
In [ ]:
# train_paths = glob.glob(dataset_path + "train/*.*")
# test_paths = glob.glob(dataset_path + "val/*.*")
# Deterministic 98/2 train/test split over all files in dataset_path.
# sorted() fixes the file order so the split is reproducible across runs.
train_paths, test_paths = train_test_split(sorted(glob.glob(dataset_path + "/*.*")), test_size=0.02, random_state=42)
# Test loader uses batch_size=1 so each saved sample grid shows a single image.
# NOTE(review): num_workers=n_cpu assumes 8 cores — confirm on the target machine.
train_dataloader = DataLoader(ImageDataset(train_paths, hr_shape=hr_shape), batch_size=batch_size, shuffle=True, num_workers=n_cpu)
test_dataloader = DataLoader(ImageDataset(test_paths, hr_shape=hr_shape), batch_size=1, shuffle=False, num_workers=n_cpu)
In [ ]:
class FeatureExtractor(nn.Module):
    """Fixed VGG19 prefix (first 18 modules of vgg19.features) used as the
    perceptual loss network for the SRGAN content loss."""

    def __init__(self):
        super(FeatureExtractor, self).__init__()
        # pretrained=True downloads ImageNet weights on first use (legacy
        # torchvision API; newer versions prefer the weights= argument).
        vgg19_model = vgg19(pretrained=True)
        self.feature_extractor = nn.Sequential(*list(vgg19_model.features.children())[:18])
        # Improvement: this is a frozen loss network — its parameters are never
        # optimized (no optimizer holds them), so disable their gradients to
        # avoid computing unused VGG parameter grads during loss.backward().
        # This saves memory/compute without changing the generator's gradients.
        for param in self.feature_extractor.parameters():
            param.requires_grad = False

    def forward(self, img):
        return self.feature_extractor(img)
class ResidualBlock(nn.Module):
    """SRGAN generator unit: Conv-BN-PReLU-Conv-BN with an identity skip.

    The block preserves channel count and spatial size, so the input can be
    added directly to the conv path's output.
    """

    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()
        # Bugfix: the original called nn.BatchNorm2d(in_features, 0.8), which
        # passes 0.8 positionally as `eps` (default 1e-5) — a nonsensical
        # epsilon that badly distorts normalization. The 0.8 was intended as
        # the running-stats momentum, so pass it by keyword.
        self.conv_block = nn.Sequential(
            nn.Conv2d(in_features, in_features, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(in_features, momentum=0.8),
            nn.PReLU(),
            nn.Conv2d(in_features, in_features, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(in_features, momentum=0.8),
        )

    def forward(self, x):
        # Residual connection: output = x + F(x).
        return x + self.conv_block(x)
In [ ]:
class GeneratorResNet(nn.Module):
    """SRResNet/SRGAN generator: 4x upscaling via a residual trunk, a long
    skip connection, and two 2x PixelShuffle stages.

    Input:  (N, in_channels, H, W)
    Output: (N, out_channels, 4H, 4W), Tanh-bounded to [-1, 1].
    """

    def __init__(self, in_channels=3, out_channels=3, n_residual_blocks=16):
        super(GeneratorResNet, self).__init__()
        # First layer: large 9x9 receptive field + PReLU.
        self.conv1 = nn.Sequential(nn.Conv2d(in_channels, 64, kernel_size=9, stride=1, padding=4), nn.PReLU())
        # Residual trunk.
        self.res_blocks = nn.Sequential(*[ResidualBlock(64) for _ in range(n_residual_blocks)])
        # Post-trunk conv whose output is added back to conv1's output.
        # Bugfix: 0.8 was passed positionally as BatchNorm2d `eps`; it was
        # intended as the momentum (same defect as in ResidualBlock).
        self.conv2 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(64, momentum=0.8))
        # Two sub-pixel upsampling stages; each PixelShuffle trades 256
        # channels for a 2x spatial upscale: (256, H, W) -> (64, 2H, 2W).
        upsampling = []
        for _ in range(2):  # loop var was misleadingly named `out_features`
            upsampling += [
                nn.Conv2d(64, 256, 3, 1, 1),
                nn.BatchNorm2d(256),
                nn.PixelShuffle(upscale_factor=2),
                nn.PReLU(),
            ]
        self.upsampling = nn.Sequential(*upsampling)
        # Final output layer; Tanh bounds outputs to [-1, 1].
        self.conv3 = nn.Sequential(nn.Conv2d(64, out_channels, kernel_size=9, stride=1, padding=4), nn.Tanh())

    def forward(self, x):
        out1 = self.conv1(x)
        out = self.res_blocks(out1)
        out2 = self.conv2(out)
        # Long skip connection around the residual trunk.
        out = torch.add(out1, out2)
        out = self.upsampling(out)
        out = self.conv3(out)
        return out
In [ ]:
class Discriminator(nn.Module):
    """SRGAN patch discriminator.

    Eight 3x3 convolutions (four of them stride-2, halving each spatial
    dimension) followed by a single-channel conv, yielding a
    (1, H/16, W/16) map of per-patch real/fake scores.
    """

    def __init__(self, input_shape):
        super(Discriminator, self).__init__()
        self.input_shape = input_shape
        in_channels, in_height, in_width = self.input_shape
        # Four stride-2 stages shrink each spatial dim by a factor of 2**4.
        patch_h, patch_w = int(in_height / 2 ** 4), int(in_width / 2 ** 4)
        self.output_shape = (1, patch_h, patch_w)

        def stage(c_in, c_out, first_block=False):
            """One stage: stride-1 conv then stride-2 conv, each with
            LeakyReLU; BatchNorm everywhere except directly after the
            input image."""
            mods = [nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1)]
            if not first_block:
                mods.append(nn.BatchNorm2d(c_out))
            mods += [
                nn.LeakyReLU(0.2, inplace=True),
                nn.Conv2d(c_out, c_out, kernel_size=3, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ]
            return mods

        modules = []
        c_prev = in_channels
        for idx, c_next in enumerate([64, 128, 256, 512]):
            modules += stage(c_prev, c_next, first_block=(idx == 0))
            c_prev = c_next
        # Final 1-channel conv produces the patch score map.
        modules.append(nn.Conv2d(c_prev, 1, kernel_size=3, stride=1, padding=1))
        self.model = nn.Sequential(*modules)

    def forward(self, img):
        return self.model(img)
In [ ]:
# Initialize generator and discriminator
generator = GeneratorResNet()
discriminator = Discriminator(input_shape=(channels, *hr_shape))
feature_extractor = FeatureExtractor()
# Set feature extractor to inference mode (fixes VGG BatchNorm/Dropout behavior)
feature_extractor.eval()
# Losses
criterion_GAN = torch.nn.MSELoss()       # LSGAN-style MSE on patch logits
criterion_content = torch.nn.L1Loss()    # L1 on VGG feature maps (content loss)
if cuda:
    generator = generator.cuda()
    discriminator = discriminator.cuda()
    feature_extractor = feature_extractor.cuda()
    criterion_GAN = criterion_GAN.cuda()
    criterion_content = criterion_content.cuda()
# Load pretrained models (resume from the checkpoints written during training)
if load_pretrained_models:
    generator.load_state_dict(torch.load("saved_models/generator.pth"))
    discriminator.load_state_dict(torch.load("saved_models/discriminator.pth"))
# Optimizers (separate Adam optimizers for G and D, shared hyper-parameters)
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))
# Tensor factory selecting CPU vs. GPU float tensors.
# NOTE(review): torch.cuda.FloatTensor / .type(Tensor) is a legacy API;
# modern code would use .to(device) instead — behavior here is unchanged.
Tensor = torch.cuda.FloatTensor if cuda else torch.Tensor
In [ ]:
# Bookkeeping for loss curves: per-batch losses during training,
# per-epoch averages during testing.
train_gen_losses, train_disc_losses, train_counter = [], [], []
test_gen_losses, test_disc_losses = [], []
# x-coordinates (training examples seen) at which per-epoch test losses are plotted.
test_counter = [idx*len(train_dataloader.dataset) for idx in range(1, n_epochs+1)]

for epoch in range(n_epochs):

    ### Training
    gen_loss, disc_loss = 0, 0
    tqdm_bar = tqdm(train_dataloader, desc=f'Training Epoch {epoch} ', total=int(len(train_dataloader)))
    for batch_idx, imgs in enumerate(tqdm_bar):
        generator.train(); discriminator.train()
        # Configure model input
        imgs_lr = Variable(imgs["lr"].type(Tensor))
        imgs_hr = Variable(imgs["hr"].type(Tensor))
        # Adversarial ground truths: one label per discriminator output patch
        valid = Variable(Tensor(np.ones((imgs_lr.size(0), *discriminator.output_shape))), requires_grad=False)
        fake = Variable(Tensor(np.zeros((imgs_lr.size(0), *discriminator.output_shape))), requires_grad=False)

        ### Train Generator
        optimizer_G.zero_grad()
        # Generate a high resolution image from low resolution input
        gen_hr = generator(imgs_lr)
        # Adversarial loss: push D to label the SR output as real
        loss_GAN = criterion_GAN(discriminator(gen_hr), valid)
        # Content loss: L1 between VGG features of SR output and real HR image
        gen_features = feature_extractor(gen_hr)
        real_features = feature_extractor(imgs_hr)
        loss_content = criterion_content(gen_features, real_features.detach())
        # Total loss: content-dominated with a small adversarial term
        loss_G = loss_content + 1e-3 * loss_GAN
        loss_G.backward()
        optimizer_G.step()

        ### Train Discriminator
        optimizer_D.zero_grad()
        # Loss of real and fake images (gen_hr detached so G is not updated here)
        loss_real = criterion_GAN(discriminator(imgs_hr), valid)
        loss_fake = criterion_GAN(discriminator(gen_hr.detach()), fake)
        # Total loss
        loss_D = (loss_real + loss_fake) / 2
        loss_D.backward()
        optimizer_D.step()

        gen_loss += loss_G.item()
        train_gen_losses.append(loss_G.item())
        disc_loss += loss_D.item()
        train_disc_losses.append(loss_D.item())
        # Running count of training examples seen, for the loss plots
        train_counter.append(batch_idx*batch_size + imgs_lr.size(0) + epoch*len(train_dataloader.dataset))
        tqdm_bar.set_postfix(gen_loss=gen_loss/(batch_idx+1), disc_loss=disc_loss/(batch_idx+1))

    # Testing
    # NOTE(review): this loop keeps autograd tracking on; wrapping it in
    # torch.no_grad() would save memory without changing the reported losses.
    gen_loss, disc_loss = 0, 0
    tqdm_bar = tqdm(test_dataloader, desc=f'Testing Epoch {epoch} ', total=int(len(test_dataloader)))
    for batch_idx, imgs in enumerate(tqdm_bar):
        generator.eval(); discriminator.eval()
        # Configure model input
        imgs_lr = Variable(imgs["lr"].type(Tensor))
        imgs_hr = Variable(imgs["hr"].type(Tensor))
        # Adversarial ground truths
        valid = Variable(Tensor(np.ones((imgs_lr.size(0), *discriminator.output_shape))), requires_grad=False)
        fake = Variable(Tensor(np.zeros((imgs_lr.size(0), *discriminator.output_shape))), requires_grad=False)

        ### Eval Generator
        # Generate a high resolution image from low resolution input
        gen_hr = generator(imgs_lr)
        # Adversarial loss
        loss_GAN = criterion_GAN(discriminator(gen_hr), valid)
        # Content loss
        gen_features = feature_extractor(gen_hr)
        real_features = feature_extractor(imgs_hr)
        loss_content = criterion_content(gen_features, real_features.detach())
        # Total loss
        loss_G = loss_content + 1e-3 * loss_GAN

        ### Eval Discriminator
        # Loss of real and fake images
        loss_real = criterion_GAN(discriminator(imgs_hr), valid)
        loss_fake = criterion_GAN(discriminator(gen_hr.detach()), fake)
        # Total loss
        loss_D = (loss_real + loss_fake) / 2

        gen_loss += loss_G.item()
        disc_loss += loss_D.item()
        tqdm_bar.set_postfix(gen_loss=gen_loss/(batch_idx+1), disc_loss=disc_loss/(batch_idx+1))

        # Save image grid with upsampled inputs and SRGAN outputs
        # (for a random ~10% of test batches, to limit disk usage)
        if random.uniform(0,1)<0.1:
            # Upscale the LR input 4x so all three panels share a height
            imgs_lr = nn.functional.interpolate(imgs_lr, scale_factor=4)
            imgs_hr = make_grid(imgs_hr, nrow=1, normalize=True)
            gen_hr = make_grid(gen_hr, nrow=1, normalize=True)
            imgs_lr = make_grid(imgs_lr, nrow=1, normalize=True)
            # Side-by-side panels: HR target | upsampled LR input | SRGAN output
            img_grid = torch.cat((imgs_hr, imgs_lr, gen_hr), -1)
            save_image(img_grid, f"outputs/{batch_idx}_e{epoch}.png", normalize=False)
            generator.eval()
            # Configure model input (re-read: imgs_lr/imgs_hr were overwritten above)
            imgs_lr = Variable(imgs["lr"].type(Tensor))
            imgs_hr = Variable(imgs["hr"].type(Tensor))
            # Generate a high-resolution image from low-resolution input
            gen_hr = generator(imgs_lr)
            # Also save the individual HR / LR / generated images separately
            save_image(imgs_hr, f"outputs_hr/{batch_idx}_e{epoch}.png", normalize=False)
            save_image(imgs_lr, f"outputs_lr/{batch_idx}_e{epoch}.png", normalize=False)
            save_image(gen_hr, f"outputs_gen/{batch_idx}_e{epoch}.png", normalize=False)

    # Per-epoch average test losses
    test_gen_losses.append(gen_loss/len(test_dataloader))
    test_disc_losses.append(disc_loss/len(test_dataloader))

    # Save model checkpoints whenever this epoch achieved the lowest test
    # generator loss so far
    if np.argmin(test_gen_losses) == len(test_gen_losses)-1:
        torch.save(generator.state_dict(), "saved_models/generator.pth")
        torch.save(discriminator.state_dict(), "saved_models/discriminator.pth")
Training Epoch 0 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 0 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 1 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 1 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 2 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 2 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 3 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 3 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 4 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 4 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 5 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 5 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 6 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 6 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 7 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 7 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 8 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 8 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 9 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 9 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 10 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 10 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 11 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 11 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 12 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 12 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 13 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 13 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 14 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 14 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 15 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 15 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 16 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 16 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 17 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 17 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 18 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 18 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 19 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 19 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 20 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 20 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 21 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 21 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 22 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 22 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 23 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 23 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 24 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 24 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 25 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 25 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 26 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 26 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 27 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 27 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 28 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 28 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 29 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 29 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 30 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 30 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 31 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 31 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 32 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 32 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 33 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 33 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 34 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 34 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 35 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 35 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 36 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 36 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 37 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 37 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 38 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 38 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 39 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 39 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 40 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 40 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 41 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 41 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 42 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 42 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 43 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 43 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 44 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 44 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 45 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 45 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 46 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 46 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 47 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 47 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 48 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 48 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 49 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 49 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 50 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 50 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 51 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 51 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 52 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 52 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 53 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 53 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 54 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 54 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 55 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 55 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 56 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 56 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 57 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 57 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 58 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 58 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 59 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 59 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 60 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 60 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 61 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 61 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 62 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 62 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 63 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 63 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 64 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 64 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 65 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 65 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 66 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 66 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 67 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 67 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 68 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 68 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 69 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 69 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 70 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 70 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 71 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 71 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 72 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 72 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 73 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 73 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 74 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 74 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 75 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 75 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 76 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 76 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 77 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 77 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 78 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 78 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 79 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 79 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 80 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 80 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 81 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 81 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 82 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 82 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 83 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 83 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 84 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 84 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 85 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 85 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 86 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 86 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 87 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 87 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 88 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 88 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 89 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 89 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 90 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 90 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 91 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 91 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 92 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 92 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 93 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 93 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 94 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 94 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 95 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 95 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 96 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 96 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 97 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 97 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 98 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 98 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 99 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 99 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 100 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 100 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 101 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 101 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 102 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 102 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 103 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 103 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 104 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 104 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 105 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 105 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 106 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 106 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 107 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 107 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 108 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 108 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 109 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 109 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 110 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 110 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 111 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 111 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 112 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 112 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 113 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 113 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 114 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 114 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 115 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 115 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 116 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 116 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 117 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 117 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 118 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 118 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 119 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 119 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 120 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 120 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 121 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 121 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 122 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 122 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 123 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 123 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 124 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 124 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 125 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 125 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 126 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 126 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 127 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 127 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 128 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 128 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 129 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 129 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 130 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 130 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 131 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 131 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 132 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 132 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 133 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 133 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 134 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 134 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 135 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 135 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 136 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 136 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 137 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 137 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 138 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 138 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 139 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 139 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 140 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 140 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 141 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 141 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 142 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 142 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 143 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 143 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 144 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 144 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 145 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 145 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 146 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 146 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 147 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 147 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 148 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 148 : 0%| | 0/250 [00:00<?, ?it/s]
Training Epoch 149 : 0%| | 0/766 [00:00<?, ?it/s]
Testing Epoch 149 : 0%| | 0/250 [00:00<?, ?it/s]
In [ ]:
# Generator loss: per-batch training curve vs. per-epoch test markers.
fig = go.Figure()
fig.add_trace(go.Scatter(x=train_counter, y=train_gen_losses, mode='lines', name='Train Generator Loss'))
fig.add_trace(go.Scatter(x=test_counter, y=test_gen_losses, marker_symbol='star-diamond',
                         marker_color='orange', marker_line_width=1, marker_size=9, mode='markers', name='Test Generator Loss'))
# Bugfix: the original statement ended with a stray trailing comma after
# update_layout(...), turning it into a throwaway 1-tuple (a typo).
fig.update_layout(
    width=1000,
    height=500,
    title="Train vs. Test Generator Loss",
    xaxis_title="Number of training examples seen",
    yaxis_title="Adversarial + Content Loss",
)
fig.show()
In [ ]:
# Discriminator loss: per-batch training curve vs. per-epoch test markers.
fig = go.Figure()
fig.add_trace(go.Scatter(x=train_counter, y=train_disc_losses, mode='lines', name='Train Discriminator Loss'))
fig.add_trace(go.Scatter(x=test_counter, y=test_disc_losses, marker_symbol='star-diamond',
                         marker_color='orange', marker_line_width=1, marker_size=9, mode='markers', name='Test Discriminator Loss'))
# Bugfix: the original statement ended with a stray trailing comma after
# update_layout(...), turning it into a throwaway 1-tuple (a typo).
fig.update_layout(
    width=1000,
    height=500,
    title="Train vs. Test Discriminator Loss",
    xaxis_title="Number of training examples seen",
    yaxis_title="Adversarial Loss",
)
fig.show()
In [ ]:
import os
import random
import matplotlib.pyplot as plt
from PIL import Image

# Directory where the validation image grids were saved during training.
image_dir = "outputs"
image_files = os.listdir(image_dir)
# Bugfix: random.sample raises ValueError when fewer than 5 grids exist
# (e.g. early in training) — show at most 5, and only as many axes as needed.
n_show = min(5, len(image_files))
selected_files = random.sample(image_files, n_show)
# squeeze=False keeps axes 2-D so indexing works even when n_show == 1.
fig, axes = plt.subplots(1, max(n_show, 1), figsize=(15, 5), squeeze=False)
for i, image_file in enumerate(selected_files):
    image_path = os.path.join(image_dir, image_file)
    image = Image.open(image_path)
    axes[0][i].imshow(image)
    axes[0][i].axis('off')
plt.show()